<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# How to accelerate training with ONNX Runtime

Optimum integrates ONNX Runtime Training through an `ORTTrainer` API that extends `Trainer` in [Transformers](https://huggingface.co/docs/transformers/index).
With this extension, training time can be reduced by more than 35% for many popular Hugging Face models compared to PyTorch under eager mode.

The [`ORTTrainer`] and [`ORTSeq2SeqTrainer`] APIs make it easy to compose __[ONNX Runtime (ORT)](https://onnxruntime.ai/)__ with other features in [`Trainer`](https://huggingface.co/docs/transformers/main_classes/trainer).
They contain a feature-complete training loop and evaluation loop, and support hyperparameter search, mixed-precision training, and distributed training with multiple [NVIDIA](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/accelerate-pytorch-transformer-model-training-with-onnx-runtime/ba-p/2540471)
and [AMD](https://cloudblogs.microsoft.com/opensource/2021/07/13/onnx-runtime-release-1-8-1-previews-support-for-accelerated-training-on-amd-gpus-with-the-amd-rocm-open-software-platform/) GPUs.
With the ONNX Runtime backend, [`ORTTrainer`] and [`ORTSeq2SeqTrainer`] take advantage of:

* Computation graph optimizations: constant folding, node elimination, node fusion
* Efficient memory planning
* Kernel optimization
* ORT fused Adam optimizer: batches the elementwise updates applied to all the model's parameters into one or a few kernel launches
* More efficient FP16 optimizer: eliminates a great deal of device-to-host memory copies
* Mixed precision training

Test it out to achieve __lower latency, higher throughput, and larger maximum batch size__ while training models in 🤗 Transformers!

## Performance

The chart below shows impressive acceleration __from 39% to 130%__ for Hugging Face models with Optimum when __using ONNX Runtime and DeepSpeed ZeRO Stage 1__ for training.
The performance measurements were done on selected Hugging Face models with PyTorch as the baseline run, only ONNX Runtime for training as the second run, and ONNX
Runtime + DeepSpeed ZeRO Stage 1 as the final run, showing maximum gains. The baseline PyTorch runs use the AdamW optimizer, and the ORT Training
runs use the Fused Adam optimizer (available in `ORTTrainingArguments`). The runs were performed on a single NVIDIA A100 node with 8 GPUs.

<figure class="image table text-center m-0 w-full">
  <img src="https://huggingface.co/datasets/optimum/documentation-images/resolve/main/onnxruntime/onnxruntime-training-benchmark.png" alt="ONNX Runtime Training Benchmark"/>
</figure>

The version information used for these runs is as follows:
```
PyTorch: 1.14.0.dev20221103+cu116; ORT: 1.14.0.dev20221103001+cu116; DeepSpeed: 0.6.6; HuggingFace: 4.24.0.dev0; Optimum: 1.4.1.dev0; Cuda: 11.6.2
```
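
To make the percentages concrete, the sketch below shows one way a relative gain could be computed from two throughput measurements. It assumes that "acceleration" means the relative increase in training throughput over the PyTorch baseline; the function name `speedup_percent` and the sample numbers are illustrative and are not values taken from the benchmark.

```python
def speedup_percent(baseline_throughput: float, accelerated_throughput: float) -> float:
    """Relative throughput gain over the baseline, in percent."""
    if baseline_throughput <= 0:
        raise ValueError("baseline_throughput must be positive")
    return (accelerated_throughput / baseline_throughput - 1.0) * 100.0


# Hypothetical samples/second figures, chosen only to illustrate the arithmetic.
pytorch_baseline = 100.0
ort_plus_deepspeed = 230.0

gain = speedup_percent(pytorch_baseline, ort_plus_deepspeed)
print(f"{gain:.0f}% faster than the baseline")
```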

## Start by setting up the environment

To use ONNX Runtime for training, you need a machine with at least one NVIDIA or AMD GPU.

To use `ORTTrainer` or `ORTSeq2SeqTrainer`, you need to install the ONNX Runtime Training module and Optimum.

### Install ONNX Runtime

To set up the environment, we __strongly recommend__ installing the dependencies with Docker to ensure that the versions are correct and well
configured. You can find Dockerfiles for various combinations [here](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training/docker).

#### Setup for NVIDIA GPU

Below, we take the installation of `onnxruntime-training 1.14.0` as an example:

* If you want to install `onnxruntime-training 1.14.0` via [Dockerfile](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort1.14.0-cu116):

```bash
docker build -f Dockerfile-ort1.14.0-cu116 -t ort/train:1.14.0 .
```

* If you want to install the dependencies above in a local Python environment, you can pip install them once [CUDA 11.6](https://docs.nvidia.com/cuda/archive/11.6.2/) and [cuDNN 8](https://developer.nvidia.com/cudnn) are installed:

```bash
pip install onnx ninja
pip install torch==1.13.1+cu116 torchvision==0.14.1 -f https://download.pytorch.org/whl/cu116/torch_stable.html
pip install onnxruntime-training==1.14.0 -f https://download.onnxruntime.ai/onnxruntime_stable_cu116.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2
```

And run post-installation configuration:

```bash
python -m torch_ort.configure
```

#### Setup for AMD GPU

Below, we take the installation of the `onnxruntime-training` nightly build as an example:

* If you want to install `onnxruntime-training` via [Dockerfile](https://github.com/huggingface/optimum/blob/main/examples/onnxruntime/training/docker/Dockerfile-ort-nightly-rocm57):

```bash
docker build -f Dockerfile-ort-nightly-rocm57 -t ort/train:nightly .
```

* If you want to install the dependencies above in a local Python environment, you can pip install them once [ROCm 5.7](https://rocmdocs.amd.com/en/latest/deploy/linux/quick_start.html) is installed:

```bash
pip install onnx ninja
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm5.7
pip install --pre onnxruntime-training -f https://download.onnxruntime.ai/onnxruntime_nightly_rocm57.html
pip install torch-ort
pip install --upgrade protobuf==3.20.2
```

And run post-installation configuration:

```bash
python -m torch_ort.configure
```

### Install Optimum

You can install Optimum via PyPI:

```bash
pip install optimum
```

Or install from source:

```bash
pip install git+https://github.com/huggingface/optimum.git
```

This command installs the current main development version of Optimum, which may include the latest developments (new features, bug fixes). However, the
main version might not be very stable. If you run into any problems, please open an [issue](https://github.com/huggingface/optimum/issues) so
that we can fix it as soon as possible.
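
Before starting a training run, it can help to confirm that the packages from the steps above are visible to the Python interpreter. The following is a minimal sketch using only the standard library; the helper names `find_missing` and `installed_version` are illustrative, and the list of import names is an assumption based on the install commands and imports shown in this guide.

```python
import importlib.util
from importlib import metadata

# Import names for the packages installed above. Note that the `onnxruntime`
# module is assumed to be provided by the `onnxruntime-training` wheel.
REQUIRED_MODULES = ("torch", "onnxruntime", "torch_ort", "optimum")


def find_missing(modules=REQUIRED_MODULES):
    """Return the import names that cannot be located in this environment."""
    missing = []
    for name in modules:
        try:
            spec = importlib.util.find_spec(name)
        except (ImportError, ValueError):
            spec = None
        if spec is None:
            missing.append(name)
    return missing


def installed_version(distribution):
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


missing = find_missing()
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules were found.")

# Distribution names as used in the pip commands above.
for dist in ("torch", "onnxruntime-training", "torch-ort", "optimum"):
    print(f"{dist}: {installed_version(dist) or 'not installed'}")
```

This check only verifies that the packages can be found; it does not validate the GPU driver, CUDA, or ROCm setup.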

## ORTTrainer

The [`ORTTrainer`] class inherits from the [`Trainer`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#trainer)
class in Transformers. You can easily adapt your code by replacing `Trainer` from Transformers with `ORTTrainer` to take advantage of the acceleration
provided by ONNX Runtime. Here is an example of how to use `ORTTrainer` compared with `Trainer`:

```diff
-from transformers import Trainer, TrainingArguments
+from optimum.onnxruntime import ORTTrainer, ORTTrainingArguments

# Step 1: Define training arguments
-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
    output_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",
    ...
)

# Step 2: Create your ONNX Runtime Trainer
-trainer = Trainer(
+trainer = ORTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
+   feature="text-classification",
    ...
)

# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
```

Check out more detailed [example scripts](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training) in the Optimum repository.


## ORTSeq2SeqTrainer

The [`ORTSeq2SeqTrainer`] class is similar to the [`Seq2SeqTrainer`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Seq2SeqTrainer)
class in Transformers. You can easily adapt your code by replacing `Seq2SeqTrainer` from Transformers with `ORTSeq2SeqTrainer` to take advantage of the acceleration
provided by ONNX Runtime. Here is an example of how to use `ORTSeq2SeqTrainer` compared with `Seq2SeqTrainer`:

```diff
-from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainer, ORTSeq2SeqTrainingArguments

# Step 1: Define training arguments
-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
    output_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",
    ...
)

# Step 2: Create your ONNX Runtime Seq2SeqTrainer
-trainer = Seq2SeqTrainer(
+trainer = ORTSeq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
+   feature="text2text-generation",
    ...
)

# Step 3: Use ONNX Runtime for training!🤗
trainer.train()
```

Check out more detailed [example scripts](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training) in the Optimum repository.


## ORTTrainingArguments

The [`ORTTrainingArguments`] class inherits from the [`TrainingArguments`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.TrainingArguments)
class in Transformers. Besides the optimizers implemented in Transformers, it allows you to use the optimizers implemented in ONNX Runtime.
Replace `TrainingArguments` with `ORTTrainingArguments`:

```diff
-from transformers import TrainingArguments
+from optimum.onnxruntime import ORTTrainingArguments

-training_args = TrainingArguments(
+training_args = ORTTrainingArguments(
    output_dir="path/to/save/folder/",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",  # Fused Adam optimizer implemented by ORT
)
```
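
If the same script needs to switch between the ORT fused optimizer and the Transformers default, the keyword arguments can be collected in one place first. The sketch below uses only the argument names and values that appear in the diff above; the helper `ort_training_kwargs` and its `use_ort_fused_adam` flag are hypothetical and not part of Optimum.

```python
def ort_training_kwargs(output_dir: str, use_ort_fused_adam: bool = True) -> dict:
    """Collect keyword arguments for `ORTTrainingArguments` (or `TrainingArguments`)."""
    return {
        "output_dir": output_dir,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 8,
        "per_device_eval_batch_size": 8,
        "warmup_steps": 500,
        "weight_decay": 0.01,
        "logging_dir": output_dir,
        # "adamw_ort_fused" selects the fused Adam optimizer implemented by ORT;
        # "adamw_hf" is the value that the diff above replaces.
        "optim": "adamw_ort_fused" if use_ort_fused_adam else "adamw_hf",
    }


kwargs = ort_training_kwargs("path/to/save/folder/")
print(kwargs["optim"])
```

With Optimum installed, the dictionary can then be unpacked as `ORTTrainingArguments(**kwargs)`.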


<Tip warning={false}>

DeepSpeed is supported by ONNX Runtime (only ZeRO stage 1 and 2 for the moment).
You can find some [DeepSpeed configuration examples](https://github.com/huggingface/optimum/tree/main/tests/onnxruntime/ds_configs)
in the Optimum repository.

</Tip>

## ORTSeq2SeqTrainingArguments

The [`ORTSeq2SeqTrainingArguments`] class inherits from the [`Seq2SeqTrainingArguments`](https://huggingface.co/docs/transformers/main/en/main_classes/trainer#transformers.Seq2SeqTrainingArguments)
class in Transformers. Besides the optimizers implemented in Transformers, it allows you to use the optimizers implemented in ONNX Runtime.
Replace `Seq2SeqTrainingArguments` with `ORTSeq2SeqTrainingArguments`:


```diff
-from transformers import Seq2SeqTrainingArguments
+from optimum.onnxruntime import ORTSeq2SeqTrainingArguments

-training_args = Seq2SeqTrainingArguments(
+training_args = ORTSeq2SeqTrainingArguments(
    output_dir="path/to/save/folder/",
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="path/to/save/folder/",
-   optim = "adamw_hf",
+   optim="adamw_ort_fused",  # Fused Adam optimizer implemented by ORT
)
```

<Tip warning={false}>

DeepSpeed is supported by ONNX Runtime (only ZeRO stage 1 and 2 for the moment).
You can find some [DeepSpeed configuration examples](https://github.com/huggingface/optimum/tree/main/tests/onnxruntime/ds_configs)
in the Optimum repository.

</Tip>
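
As a rough illustration of what such a configuration might contain, the sketch below writes a minimal ZeRO config file and rejects stages other than 1 and 2. The helper `make_zero_config` and the file name are hypothetical, and the `"auto"` placeholders assume that the Transformers DeepSpeed integration fills them in from the training arguments; the configuration examples linked above remain the reference.

```python
import json
import os
import tempfile

# Only ZeRO stages 1 and 2 are supported, per the note above.
SUPPORTED_ZERO_STAGES = (1, 2)


def make_zero_config(stage: int = 1) -> dict:
    """Build a minimal DeepSpeed config dictionary for the given ZeRO stage."""
    if stage not in SUPPORTED_ZERO_STAGES:
        raise ValueError(f"ZeRO stage {stage} is not supported; use 1 or 2")
    return {
        "zero_optimization": {"stage": stage},
        # "auto" defers these values to the training arguments (assumption:
        # the Transformers DeepSpeed integration resolves them).
        "fp16": {"enabled": "auto"},
        "train_micro_batch_size_per_gpu": "auto",
        "gradient_accumulation_steps": "auto",
    }


config = make_zero_config(stage=1)
config_path = os.path.join(tempfile.mkdtemp(), "ds_config_zero1.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)
print(f"Wrote DeepSpeed config to {config_path}")
```

The resulting path could then be passed through the `deepspeed` argument that the training arguments classes inherit from Transformers, for example `deepspeed=config_path`.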

## ORTModule+StableDiffusion

Optimum supports accelerating Hugging Face Diffusers with ONNX Runtime in [this example](https://github.com/huggingface/optimum/tree/main/examples/onnxruntime/training/stable-diffusion/text-to-image).
The core changes required to enable ONNX Runtime Training are summarized below:

```diff
import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel

+from onnxruntime.training.ortmodule import ORTModule
+from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="unet",
    ...
)
text_encoder = CLIPTextModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="text_encoder",
    ...
)
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", 
    subfolder="vae",
    ...
)

optimizer = torch.optim.AdamW(
    unet.parameters(),
    ...
)

+vae = ORTModule(vae)
+text_encoder = ORTModule(text_encoder)
+unet = ORTModule(unet)

+optimizer = ORT_FP16_Optimizer(optimizer)
```
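
To keep a single training script usable both with and without ONNX Runtime, the wrapping step can be made optional. The sketch below is one possible pattern, not the approach taken in the linked example: `maybe_wrap_with_ort` is a hypothetical helper that falls back to the unwrapped module when `onnxruntime-training` is not installed or when wrapping is disabled.

```python
def maybe_wrap_with_ort(module, enabled: bool = True):
    """Wrap `module` in ORTModule when requested and available, else return it unchanged."""
    if not enabled:
        return module
    try:
        # Same import path as in the diff above.
        from onnxruntime.training.ortmodule import ORTModule
    except ImportError:
        print("onnxruntime-training not found; falling back to plain PyTorch.")
        return module
    return ORTModule(module)


# A placeholder stands in for a real torch.nn.Module so the snippet runs anywhere.
placeholder = object()
unchanged = maybe_wrap_with_ort(placeholder, enabled=False)
print(unchanged is placeholder)
```

In a real script, the same call would be applied to `vae`, `text_encoder`, and `unet`, with `enabled` driven by a command-line flag.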

## Other Resources

* Blog posts
    * [Optimum + ONNX Runtime: Easier, Faster training for your Hugging Face models](https://huggingface.co/blog/optimum-onnxruntime-training)
    * [Accelerate PyTorch transformer model training with ONNX Runtime - a deep dive](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/accelerate-pytorch-transformer-model-training-with-onnx-runtime/ba-p/2540471)
    * [ONNX Runtime Training Technical Deep Dive](https://techcommunity.microsoft.com/t5/ai-machine-learning-blog/onnx-runtime-training-technical-deep-dive/ba-p/1398310)
* [Optimum GitHub](https://github.com/huggingface/optimum)
* [ONNX Runtime GitHub](https://github.com/microsoft/onnxruntime)
* [Torch ORT GitHub](https://github.com/pytorch/ort)
* [Download ONNX Runtime stable versions](https://download.onnxruntime.ai/)

If you have any problems or questions regarding `ORTTrainer`, please file an issue on [Optimum GitHub](https://github.com/huggingface/optimum)
or discuss with us on the [Hugging Face community forum](https://discuss.huggingface.co/c/optimum/). Cheers 🤗!